{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimizing Development for EvoX via PyTorch Advanced Techniques"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Optimization Support for Functions in PyTorch\n",
    "\n",
    "PyTorch provides fundamental optimization support for functions, primarily through vectorizing map (vmap) operations and Just-In-Time (JIT) compilation. These techniques enable efficient batch processing and enhance execution performance, respectively. Introductions of these optimizations are provided in the following sections."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batch Processing Support through Vectorizing Map in PyTorch\n",
    "\n",
    "Vectorizing map, implemented in PyTorch as [`torch.vmap`](https://pytorch.org/docs/stable/generated/torch.vmap.html), is a powerful tool that takes a callable function and returns a batched version of it. According to specified strategy, this new function vectorizes the operations of the original one, which facilitates efficient batch processing. In EvoX, for example, this feature plays a crucial role in hyperparameter optimization (HPO)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  1,  4,  9, 16, 25, 36, 49, 64],\n",
       "        [ 0,  1,  4,  9, 16, 25, 36, 49, 64],\n",
       "        [ 0,  1,  4,  9, 16, 25, 36, 49, 64]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor):\n",
    "    return pop_x * y\n",
    "\n",
    "\n",
    "batched_dummy_evaluation = torch.vmap(dummy_evaluation, (0, None))\n",
    "\n",
    "population_size = 3\n",
    "individual_vector_size = 9\n",
    "pop_x = torch.arange(individual_vector_size).repeat(population_size, 1)\n",
    "y = torch.arange(individual_vector_size)\n",
    "\n",
    "batched_dummy_evaluation(pop_x, y)"
   ]
  },
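  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what `in_dims=(0, None)` means, the sketch below compares a `torch.vmap`-batched call with an explicit Python loop over the first dimension. The names `weighted_sum`, `pop`, and `w` are hypothetical and chosen only for this sketch; they are not part of EvoX.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "# Hypothetical per-individual function: weighted sum of a single vector\n",
    "def weighted_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n",
    "    return (x * w).sum()\n",
    "\n",
    "\n",
    "# Map over dim 0 of the first argument; share the second argument as-is\n",
    "batched_weighted_sum = torch.vmap(weighted_sum, in_dims=(0, None))\n",
    "\n",
    "pop = torch.arange(6.0).reshape(3, 2)  # 3 individuals, 2 values each\n",
    "w = torch.tensor([1.0, 2.0])  # one weight vector shared by all individuals\n",
    "\n",
    "batched_out = batched_weighted_sum(pop, w)\n",
    "# Reference result: call the original function once per row\n",
    "looped_out = torch.stack([weighted_sum(row, w) for row in pop])\n",
    "```\n",
    "\n",
    "Under these assumptions, `batched_out` and `looped_out` should hold the same values, while the batched version avoids the Python-level loop."
   ]
  },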
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Just-In-Time (JIT) Support in PyTorch\n",
    "\n",
    "In PyTorch, [`torch.jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and [`torch.jit.script`](https://pytorch.org/docs/stable/generated/torch.jit.script.html) provide two distinct types of JIT tools, supporting function performance optimization through tracing and scripting, respectively.\n",
    "\n",
    "Based on the tracing strategy, the `torch.jit.trace` method offers higher parsing speed and broader compatibility, such as with `torch.vmap` operations. Although it provides excellent support for simple functions, it is not suitable for complex tasks involving dynamic if-else branches and loop control flows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "\n",
    "@functools.partial(torch.vmap, in_dims=(0, None))\n",
    "def vmap_sample_func(x: torch.Tensor, y: torch.Tensor):\n",
    "    return x.sum() + y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the example below, the traced `vmap` function successfully returns the correct code representation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def vmap_sample_func(x: Tensor,\n",
      "    y: Tensor) -> Tensor:\n",
      "  _0 = torch.add(torch.view(torch.sum(x, [1]), [3, 1]), y)\n",
      "  return _0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "traced_vmap_func = torch.jit.trace(vmap_sample_func, example_inputs=(pop_x, y))\n",
    "print(traced_vmap_func.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, dynamic python control-flow cannot be traced correctly and a warning will be raised:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def dynamic_control_flow(pop_x: Tensor,\n",
      "    y: Tensor) -> Tensor:\n",
      "  y0 = torch.flatten(y)\n",
      "  _0 = torch.slice(torch.unsqueeze(y0, 0), 1, 0, 9223372036854775807)\n",
      "  return torch.mul(pop_x, _0)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\skk77\\AppData\\Local\\Temp\\ipykernel_19564\\349648791.py:2: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if y.flatten()[0] > 0:\n"
     ]
    }
   ],
   "source": [
    "def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):\n",
    "    if y.flatten()[0] > 0:\n",
    "        return pop_x + y[None, :]\n",
    "    else:\n",
    "        return pop_x * y[None, :]\n",
    "\n",
    "\n",
    "traced_dynamic_control_flow_func = torch.jit.trace(dynamic_control_flow, example_inputs=(pop_x, y))\n",
    "print(traced_dynamic_control_flow_func.code)"
   ]
  },
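  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the consequence of this warning more concrete, the sketch below traces a small branching function with one input and then calls it with an input that should take the other branch. The names `branch`, `pos`, and `neg` are hypothetical and used only for illustration.\n",
    "\n",
    "```python\n",
    "import warnings\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "# Hypothetical function whose result depends on a data-dependent branch\n",
    "def branch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "    if y.flatten()[0] > 0:\n",
    "        return x + y\n",
    "    return x * y\n",
    "\n",
    "\n",
    "x = torch.ones(3)\n",
    "pos = torch.full((3,), 2.0)  # takes the add branch\n",
    "neg = torch.full((3,), -2.0)  # should take the multiply branch\n",
    "\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')  # silence the TracerWarning in this sketch\n",
    "    traced_branch = torch.jit.trace(branch, (x, pos))\n",
    "\n",
    "eager_out = branch(x, neg)  # eager mode evaluates the condition again\n",
    "traced_out = traced_branch(x, neg)  # the trace replays the recorded add branch\n",
    "```\n",
    "\n",
    "Here the eager call multiplies, while the traced call still adds, because the condition was frozen as a constant when the function was traced."
   ]
  },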
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, the `torch.jit.script` method, which adopts a scripting strategy, is better suited for complex tasks that involve dynamic control flows but has limited compatibility."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, the same `vmap_sample_func` function, after being scripted, returns an **incorrect** code representation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def vmap_sample_func(x: Tensor,\n",
      "    y: Tensor) -> Tensor:\n",
      "  return torch.add(torch.sum(x), y)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "scripted_vmap_func = torch.jit.script(vmap_sample_func)\n",
    "print(scripted_vmap_func.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yet, it can correctly deal with complex dynamic python control flow:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def dynamic_control_flow(pop_x: Tensor,\n",
      "    y: Tensor) -> Tensor:\n",
      "  _0 = torch.gt(torch.select(torch.flatten(y), 0, 0), 0)\n",
      "  if bool(_0):\n",
      "    _2 = torch.slice(torch.unsqueeze(y, 0), 1)\n",
      "    _1 = torch.add(pop_x, _2)\n",
      "  else:\n",
      "    _3 = torch.slice(torch.unsqueeze(y, 0), 1)\n",
      "    _1 = torch.mul(pop_x, _3)\n",
      "  return _1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def dynamic_control_flow(pop_x: torch.Tensor, y: torch.Tensor):\n",
    "    if y.flatten()[0] > 0:\n",
    "        return pop_x + y[None, :]\n",
    "    else:\n",
    "        return pop_x * y[None, :]\n",
    "\n",
    "\n",
    "script_dynamic_control_flow_func = torch.jit.script(dynamic_control_flow)\n",
    "print(script_dynamic_control_flow_func.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "`torch.jit.script` relies on type hint to work properly. For example, any unannotated input argument is treated as a `torch.Tensor` while you can annotate some input arguments to be python types to make `torch.jit.script` work as intended.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Combined Usage of JIT and Vectorizing Map in PyTorch\n",
    "\n",
    "Based on the introductions above, when `torch.jit.trace` and `torch.jit.script` are used in combination with `torch.vmap`, coordination is required due to compatibility considerations.\n",
    "\n",
    "The figure below illustrates the relationship between `torch.jit.script`, `torch.jit.trace`, and `torch.vmap`, highlighting their mutual invocation paths. If module A invokes module B, it implies that B can be called by A.\n",
    "\n",
    "```{image} /_static/jit_vmap.png\n",
    ":alt: JIT introduction\n",
    ":align: center\n",
    "```\n",
    "\n",
    "For detailed usage of JIT and vectorizing map on PyTorch, please refer to the official PyTorch documentation for [TorchScript](#https://pytorch.org/docs/stable/jit.html) and [`torch.vmap`](#https://pytorch.org/docs/stable/generated/torch.vmap.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specific Optimization Support in EvoX\n",
    "\n",
    "Within EvoX, most functions are defined inside classes, particularly subclasses of [`ModuleBase`](#evox.core.module.ModuleBase). To provide more comprehensive optimization supports, EvoX offers specific enhancements."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using JIT to Subclasses of [`ModuleBase`](#evox.core.module.ModuleBase)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For better understanding of this part, we need to explain three important functions in EvoX: [`jit_class`](#evox.core.module.jit_class), [`vmap`](#evox.core.jit_util.vmap) and [`jit`](l#evox.core.jit_util.jit)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### [`jit_class`](#evox.core.module.jit_class) Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[`jit_class`](#evox.core.module.jit_class) is a helper function used to Just-In-Time (JIT) script of [`torch.jit.script`](https://pytorch.org/docs/stable/generated/torch.jit.script.html) or trace ([`torch.jit.trace_module`](https://pytorch.org/docs/stable/generated/torch.jit.trace_module.html#torch-jit-trace-module)) all member methods of the input class. \n",
    "\n",
    "[`jit_class`](#evox.core.module.jit_class) has two parameters:\n",
    "\n",
    "- `cls`: the original class whose member methods are to be lazy JIT.\n",
    "- `trace`: whether to trace the module or to script the module. Default to `False`.\n",
    "\n",
    "```{note}\n",
    "1. In many cases, it is not necessary to wrap your custom algorithms or problems with [`jit_class`](#evox.core.module.jit_class), the workflow(s) will do the trick for you.\n",
    "2. With `trace=True`, all the member functions are effectively modified to return `self` additionally since side-effects cannot be traced. If you want to preserve the side effects, please set `trace=False` and use the `use_state` function to wrap the member method to generate pure-functional (the `use_state` function will be explained in the next part).\n",
    "3. Similarly, all module-wide operations like `self.to(...)` can only returns the unwrapped module, which may not be desired. Since most of them are in-place operations, a simple `module.to(...)` can be used instead of `module = module.to(...)`.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "#### [`vmap`](#evox.core.jit_util.vmap) Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[`vmap`](#evox.core.jit_util.vmap) function vectorized map the given function to its mapped version. Based on [`torch.vmap`](https://pytorch.org/docs/main/generated/torch.vmap.html), we made many improvements, and you can see [`torch.vmap`](https://pytorch.org/docs/main/generated/torch.vmap.html) for more information."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### [`jit`](#evox.core.jit_util.jit) Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[`jit`](#evox.core.jit_util.jit) compile the given `func` via [`torch.jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.script.html) (`trace=True`) or [`torch.jit.script`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) (`trace=False`).\n",
    "\n",
    "  This function wrapper effectively deals with nested JIT and vector map (`vmap`) expressions like `jit(func1)` -> `vmap` -> `jit(func2)`, preventing possible errors.\n",
    "\n",
    "```{note}\n",
    "1. With `trace=True`, `torch.jit.trace` cannot use SAME example input arguments for function of DIFFERENT parameters,e.g., you cannot pass `tensor_a, tensor_a` to `torch.jit.trace`d version of `f(x: torch.Tensor, y: torch.Tensor)`.\n",
    "2. With `trace=False`, `torch.jit.script` cannot contain `vmap` expressions directly, please wrap them with `jit(..., trace=True)` or `torch.jit.trace`.\n",
    "```\n",
    "\n",
    "In the [Working with Module in EvoX](#/guide/developer/1-modulebase), we have briefly introduced some rules about the methods inside a subclass of the [`ModuleBase`](#evox.core.module.ModuleBase) . Now that [`jit_class`](#evox.core.module.jit_class), [`vmap`](#evox.core.jit_util.vmap) and [`jit`](#evox.core.jit_util.jit) have been explained, we will explain more rules and provide some specific hints."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Definition of Static Methods Inside the Subclass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the subclass, static methods to be JIT shall be defined like:\n",
    "\n",
    "```Python\n",
    "# Import Pytorch\n",
    "import torch\n",
    "\n",
    "# Import the ModuleBase class from EvoX\n",
    "from evox.core import ModuleBase, jit\n",
    "\n",
    "# Set an module inherited from the ModuleBase class\n",
    "class ExampleModule(ModuleBase):\n",
    "    \n",
    "    ...\n",
    "    \n",
    "    # One example of the static method defined in a Module \n",
    "    @jit\n",
    "    def func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "        return x + y\n",
    "    \n",
    "    ...\n",
    "    \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Definition of Non-static Methods Inside the Subclass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a method with **Python dynamic control flows** like `if` were to be JIT, a separated static method with `jit(..., trace=False)` or `torch.jit.script_if_tracing` shall be used:\n",
    "\n",
    "```python\n",
    "# Import Pytorch\n",
    "import torch\n",
    "\n",
    "# Import the ModuleBase class from EvoX\n",
    "from evox.core import ModuleBase, jit\n",
    "\n",
    "# Set an module inherited from the ModuleBase class\n",
    "class ExampleModule(ModuleBase):\n",
    "    \n",
    "    ...\n",
    "    \n",
    "    # An example of one method with python dynamic control flows like \"if\"\n",
    "    # The method using jit(..., trace=False)\n",
    "    @partial(jit, trace=False)\n",
    "    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:\n",
    "        if x.flatten()[0] > threshold:\n",
    "            return torch.sin(x)\n",
    "        else:\n",
    "            return torch.tan(x)\n",
    "        \n",
    "    # The method to be JIT   \n",
    "    @jit\n",
    "    def jit_func(self, p: torch.Tensor) -> torch.Tensor:\n",
    "        return ExampleModule.static_func(p, self.threshold)\n",
    "    \n",
    "    ...\n",
    "    \n",
    "```\n",
    "\n",
    "```{note}\n",
    "Dynamic control flow in Python refers to control structures that change dynamically based on conditions at runtime.\n",
    "`if...elif...else` Conditional Statements, `for`loop and `while` loop are all dynamic control flows. If you have to use them when defining non-static Methods inside the subclass of [`ModuleBase`](#evox.core.module.ModuleBase), please follow the above rule. \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Invocation of External Methods Inside the Subclass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the subclass, external JIT methods can be invocated by the class methods to be JIT:\n",
    "\n",
    "```python\n",
    "# Import the ModuleBase class from EvoX\n",
    "from evox.core import ModuleBase\n",
    "\n",
    "# One example of the JIT method defined outside the module \n",
    "@jit\n",
    "def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "    return x + y\n",
    "\n",
    "# Set an module inherited from the ModuleBase class\n",
    "class ExampleModule(ModuleBase):\n",
    "    \n",
    "    ...\n",
    "\n",
    "    # The internal method using jit(..., trace=False)\n",
    "    @partial(jit, trace=False)\n",
    "    def static_func(x: torch.Tensor, threshold: float) -> torch.Tensor:\n",
    "\n",
    "        \n",
    "    # The internal static method to be JIT   \n",
    "    @jit\n",
    "    def jit_func(self, p: torch.Tensor) -> torch.Tensor:\n",
    "        return external_func(p, p)\n",
    "    \n",
    "    ...\n",
    "    \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Automatically JIT for the Subclass Used with `jit_class`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[`ModuleBase`](#evox.core.module.ModuleBase)  and its subclasses are usually used with [`jit_class`](#evox.core.module.jit_class) to automatically JIT all non-magic member methods:\n",
    "\n",
    "```python\n",
    "# Import Pytorch\n",
    "import torch\n",
    "\n",
    "# Import the ModuleBase class from EvoX\n",
    "from evox.core import ModuleBase, jit_class\n",
    "\n",
    "@jit_class\n",
    "class ExampleModule(ModuleBase):\n",
    "    \n",
    "    ...\n",
    "    \n",
    "    # This function will be automatically JIT\n",
    "    def automatically_JIT_func1(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        pass\n",
    "\n",
    "    # Use `torch.jit.ignore` to disable JIT and leave this function as Python callback\n",
    "    @torch.jit.ignore\n",
    "    def no_JIT_func2(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # you can implement pure Python logic here\n",
    "        pass\n",
    "\n",
    "    # JIT functions can invoke other JIT functions as well as non-JIT functions\n",
    "    def automatically_JIT_func3(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        y = self.automatically_JIT_func1(x)\n",
    "        z = self.no_JIT_func2(x)\n",
    "        pass\n",
    "    \n",
    "    ...\n",
    "    \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Invocation of External Vmap-wrapped Methods Inside the Subclass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the subclass, external vmap-wrapped methods can be invocated by the class methods to be JIT:\n",
    "\n",
    "```Python\n",
    "# Import Pytorch\n",
    "import torch\n",
    "\n",
    "# Import the ModuleBase class from EvoX\n",
    "from evox.core import ModuleBase, jit, vmap\n",
    "\n",
    "\n",
    "# The method to be vmap-wrapped\n",
    "def external_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "    return x + y.sum()\n",
    "\n",
    "external_vmap_func = vmap(external_func, in_dims=1, out_dims=1)\n",
    "\n",
    "\n",
    "# Set an module inherited from the ModuleBase class\n",
    "class ExampleModule(ModuleBase):\n",
    "    \n",
    "    ...    \n",
    "    \n",
    "    # The internal class method to be JIT   \n",
    "    @jit\n",
    "    def jit_func(self, p: torch.Tensor) -> torch.Tensor:\n",
    "        return external_vmap_func(p, p)\n",
    "    \n",
    "    ...\n",
    "    \n",
    "```\n",
    "\n",
    "```{note}\n",
    "If method A invokes vmap-wrapped method B, then A and all methods invoke method A can not be vmap-wrapped again.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Internal Vmap-wrapped Methods Inside the Subclass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the subclass, internal vmap-wrapped methods can be  JIT by using the [`trace_impl`](#evox.core.module.trace_impl):\n",
    "\n",
    "```Python\n",
    "# Import Pytorch\n",
    "import torch\n",
    "\n",
    "# Import the ModuleBase class from EvoX\n",
    "from evox.core import ModuleBase, jit, vmap, trace_impl\n",
    "\n",
    "\n",
    "# Set an module inherited from the ModuleBase class\n",
    "class ExampleModule(ModuleBase):\n",
    "    \n",
    "    ...    \n",
    "    \n",
    "    # The internal vmap-wrapped class method to be JIT   \n",
    "    @jit\n",
    "    def jit_vmap_func(self, p: torch.Tensor) -> torch.Tensor:\n",
    "        \n",
    "        # The original method\n",
    "        # We can not vmap it\n",
    "    \tdef func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "    \t\treturn x + y\n",
    "        \n",
    "        \n",
    "        # The method to be vmap-wrapped\n",
    "        # We need to use trace_impl to rewrite the original method\n",
    "        @trace_impl(func)\n",
    "    \tdef trace_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
    "    \t\tpass\n",
    "        \n",
    "        return vmap(func, in_dims=1, out_dims=1, trace=False)(p, p)\n",
    "    \n",
    "    ...\n",
    "    \n",
    "```\n",
    "\n",
    "```{note}\n",
    "If a class method use [`trace_impl`](#evox.core.module.trace_impl), it will be only available in the trace mode. More details about `trace_impl` will be shown in the next part.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using `@trace_impl` and `@vmap_impl`\n",
    "\n",
    "When designing a function or method, you may not always consider whether it is `JIT`-compatible. However, this property becomes crucial in specific scenarios, such as solving Hyperparameter Optimization (HPO) problems. For more details on deploying HPO with EvoX, refer to [Efficient HPO with EvoX](#/guide/user/3-hpo).\n",
    "\n",
    "A typical characteristic of such problems is that only certain parts of the algorithm need modification—for instance, the `step` method of an algorithm. This allows you to avoid rewriting the entire algorithm. In such cases, you can use the `@trace_impl` or `@vmap_impl` decorator to rewrite the function as a trace-JIT-time or vmap-JIT-time proxy for the specified `target` method.\n",
    "\n",
    "The decorators [`@trace_impl`](#trace_impl) and [`@vmap_impl`](#vmap_impl) accept a single input parameter: the target method invoked when not tracing/vmapping JIT. These decorators are applicable **only** to member methods within a `jit_class`.\n",
    "\n",
    "Since the annotated function serves as a rewritten version of the target function, it must maintain identical input/output signatures (e.g., number and types of arguments). Otherwise, the resulting behavior is undefined.\n",
    "\n",
    "If the annotated function is intended for use with `vmap`, it must satisfy three additional constraints:\n",
    "\n",
    "1. **No In-Place Operations on Attributes:**\n",
    "   The algorithm must not include methods that perform in-place operations on its attributes.\n",
    "\n",
    "```python\n",
    "class ExampleAlgorithm(Algorithm):\n",
    "    def __init__(self, ...):\n",
    "        self.pop = torch.rand(10, 10)  # Attribute of the algorithm\n",
    "\n",
    "    def step_in_place(self):  # Method with in-place operations\n",
    "        self.pop.copy_(pop)\n",
    "\n",
    "    def step_out_of_place(self):  # Method without in-place operations\n",
    "        self.pop = pop\n",
    "```\n",
    "\n",
    "2. **Avoid Python Control Flow:**\n",
    "   The code logic must not rely on Python control flow structures. To handle Python control flow, use [`TracingCond`](#TracingCond), [`TracingWhile`](#TracingWhile), and [`TracingSwitch`](#TracingSwitch).\n",
    "\n",
    "```python\n",
    "@jit_class\n",
    "class ExampleAlgorithm(Algorithm):\n",
    "    def __init__(self, pop_size, ...):\n",
    "        super().__init__()\n",
    "        self.pop = torch.rand(pop_size, pop_size)\n",
    "\n",
    "    def strategy_1(self):  # One update strategy\n",
    "        new_pop = self.pop * self.pop\n",
    "        self.pop = new_pop\n",
    "\n",
    "    def strategy_2(self):  # Another update strategy\n",
    "        new_pop = self.pop + self.pop\n",
    "        self.pop = new_pop\n",
    "\n",
    "    def step(self):\n",
    "        control_number = torch.rand()\n",
    "        if control_number < 0.5:  # Conditional control\n",
    "            self.strategy_1()\n",
    "        else:\n",
    "            self.strategy_2()\n",
    "\n",
    "    @trace_impl(step)  # Rewrite step function for vmap support\n",
    "    def trace_step_without_operations_to_self(self):\n",
    "        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)\n",
    "        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]\n",
    "        pop = pop * self.hp[0]\n",
    "        control_number = torch.rand()\n",
    "        cond = control_number < 0.5\n",
    "        branches = (self.strategy_1, self.strategy_2)\n",
    "        state, names = self.prepare_control_flow(*branches)  # Utilize state to track self.pop\n",
    "        _if_else_ = TracingCond(*branches)\n",
    "        state = _if_else_.cond(state, cond, pop)\n",
    "        self.after_control_flow(state, *names)\n",
    "\n",
    "    @trace_impl(step)\n",
    "    def trace_step_with_operations_to_self(self):\n",
    "        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)\n",
    "        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]\n",
    "        pop = pop * self.hp[0]\n",
    "        control_number = torch.rand()\n",
    "        cond = control_number < 0.5\n",
    "        _if_else_ = TracingCond(lambda p: p * p, lambda p: p + p)  # No need to track self.pop\n",
    "        pop = _if_else_.cond(cond, pop)\n",
    "        self.pop = pop\n",
    "```\n",
    "\n",
    "3. **Avoid In-Place Operations on `self`:**\n",
    "   Vectorized map in-place operations on `self` are not well-defined and cannot be compiled. Even if it is compiled successfully, you can still silently get incorrect results.\n",
    "\n",
    "### Using `use_state`\n",
    "\n",
    "[`use_state`](#use_state) transforms a given stateful function (which performs in-place alterations on `nn.Module`s) into a pure-functional version that receives an additional `state` parameter (of type `Dict[str, torch.Tensor]`) and returns the altered state.\n",
    "\n",
    "The input `func` is the stateful function to be transformed or its generator function, and `is_generator` specifies whether `func` is a function or a function generator (e.g., a lambda that returns the stateful function). It defaults to `True`.\n",
    "\n",
    "Here is a simple example:\n",
    "\n",
    "```python\n",
    "@jit_class\n",
    "class Example(ModuleBase):\n",
    "    def __init__(self, threshold=0.5):\n",
    "        super().__init__()\n",
    "        self.threshold = threshold\n",
    "        self.sub_mod = nn.Module()\n",
    "        self.sub_mod.buf = nn.Buffer(torch.zeros(()))\n",
    "\n",
    "    def h(self, q: torch.Tensor) -> torch.Tensor:\n",
    "        if q.flatten()[0] > self.threshold:\n",
    "            x = torch.sin(q)\n",
    "        else:\n",
    "            x = torch.tan(q)\n",
    "        x += self.g(x).abs()\n",
    "        x *= x.shape[1]\n",
    "        self.sub_mod.buf = x.sum()\n",
    "        return x\n",
    "\n",
    "    @trace_impl(h)\n",
    "    def th(self, q: torch.Tensor) -> torch.Tensor:\n",
    "        x += self.g(x).abs()\n",
    "        x *= x.shape[1]\n",
    "        self.sub_mod.buf = x.sum()\n",
    "        return x\n",
    "\n",
    "    def g(self, p: torch.Tensor) -> torch.Tensor:\n",
    "        x = torch.cos(p)\n",
    "        return x * p.shape[0]\n",
    "\n",
    "fn = use_state(lambda: t.h, is_generator=True)\n",
    "jit_fn = jit(fn, trace=True, lazy=True)\n",
    "results = jit_fn(fn.init_state(), torch.rand(10, 1))\n",
    "print(results)  # ({\"self.sub_mod.buf\": torch.Tensor(5.6)}, torch.Tensor([[0.56], ...]))\n",
    "\n",
    "# IN-PLACE update all relevant variables using the given state\n",
    "fn.set_state(results[0])\n",
    "```\n",
    "\n",
    "### Using `core._vmap_fix`\n",
    "\n",
    "The module [`_vmap_fix`](#_vmap_fix) provides useful functions. After the automatic import, `_vmap_fix` enables `torch.vmap` to be correctly traced by `torch.jit.trace`, while resolving issues such as random number handling that couldn't be properly traced during the `vmap` process. It also provides the `debug_print` function, which allows dynamic printing of Tensor values during both `vmap` and tracing.\n",
    "\n",
    "Detailed information can be found in the [`_vmap_fix`](#_vmap_fix) documentation.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
